
import numpy as np


from src.summary.Summary import Summary

from src.clustering.utils import get_average_support
from src.clustering.base_distances_bis import (
    jaccard,
    geom_jaccard,
    norm_hamming,
    harmonic_jaccard,
    root_jaccard,
    rms_jaccard,
    geom_hamming,
    rms_hamming,
)

# Registry of base distance functions, looked up by name through the
# ``base_distance`` argument of ``_symmetric_hierarchical``.
DISTANCES = dict(
    jaccard=jaccard,
    geom_jaccard=geom_jaccard,
    norm_hamming=norm_hamming,
    harm_jaccard=harmonic_jaccard,
    root_jaccard=root_jaccard,
    rms_jaccard=rms_jaccard,
    geom_hamming=geom_hamming,
    rms_hamming=rms_hamming,
)

def find_min_in_upper_right(dict_of_dicts):
    """Return the smallest entry of a dict-of-dicts restricted to keys ``row < col``.

    Entries are visited in dict iteration order (outer, then inner). On ties
    the first entry encountered is kept.

    Args:
        dict_of_dicts: Mapping ``row -> {col: value}`` with mutually comparable
            keys and values.

    Returns:
        ``((row, col), value)`` for the minimal upper-triangle entry, or
        ``(None, None)`` when there is no entry with ``row < col``.
    """
    upper_entries = (
        ((row, col), value)
        for row, inner in dict_of_dicts.items()
        for col, value in inner.items()
        if row < col
    )
    # min() only replaces its candidate on a strict "<", so ties keep the first entry.
    best = min(upper_entries, key=lambda entry: entry[1], default=None)
    if best is None:
        return None, None
    return best

def _symmetric_hierarchical(
        P, num_v_clusters, num_c_clusters, saveas=None, func=None, base_distance=None):
    """Agglomeratively cluster the candidates (columns of ``P``); voters stay unmerged.

    Clustering starts with one cluster per candidate. The two clusters with
    the smallest linkage distance are merged repeatedly until
    ``num_c_clusters`` remain.

    The linkage distance between two clusters is ``func`` applied to the base
    distances of every (member of first, member of second) pair. The base
    distance between candidates ``a`` and ``b`` is
    ``DISTANCES[base_distance](P[:, a], P[:, b]) / P.shape[0]``.

    Args:
        P: 2-D array indexed as ``P[voter, candidate]``.
        num_v_clusters: Forwarded unchanged to ``get_average_support``.
        num_c_clusters: Number of candidate clusters to stop at; must lie in
            ``[1, P.shape[1]]``.
        saveas: Unused; accepted for interface compatibility.
        func: Reduction over a list of distances (the linkage rule), e.g.
            ``np.mean``, ``np.max`` or ``np.min``.
        base_distance: Name of the column distance; a key of ``DISTANCES``.

    Returns:
        ``Summary(avg_support, v_mapping, c_mapping)``. ``v_mapping`` maps
        every voter to its own index. ``c_mapping`` maps each candidate to the
        0-based index of its cluster.

    Raises:
        ValueError: If ``num_c_clusters`` is outside ``[1, P.shape[1]]``.
        KeyError: If ``base_distance`` is not a key of ``DISTANCES``.
    """
    num_voters, num_candidates = P.shape[0], P.shape[1]
    # Without this check an impossible target merges down to one cluster and
    # then fails with an obscure TypeError on a ``None`` index.
    if not 1 <= num_c_clusters <= num_candidates:
        raise ValueError(
            f"num_c_clusters must be between 1 and {num_candidates}, got {num_c_clusters}")

    try:
        distance_fn = DISTANCES[base_distance]
    except KeyError:
        raise KeyError(
            f"unknown base_distance {base_distance!r}; expected one of {sorted(DISTANCES)}"
        ) from None

    # Base distance between every ordered pair of single candidates. It is
    # computed once and reused by every linkage update below.
    pairwise = [[distance_fn(P[:, a], P[:, b]) / num_voters for b in range(num_candidates)]
                for a in range(num_candidates)]

    # Cluster-to-cluster linkage table; only entries [a][b] with a < b are read.
    candidate_distance_dict = {a: dict(enumerate(row)) for a, row in enumerate(pairwise)}

    # A cluster's id is its smallest member index: merges always keep the smaller id.
    candidate_cluster = {c: {c} for c in range(num_candidates)}

    while len(candidate_cluster) > num_c_clusters:
        (i, j), _ = find_min_in_upper_right(candidate_distance_dict)

        # Absorb j into i BEFORE recomputing linkages, so the merged cluster's
        # distance to every other cluster takes j's members into account.
        absorbed = candidate_cluster.pop(j)
        candidate_cluster[i] |= absorbed
        merged = candidate_cluster[i]

        for k, members in candidate_cluster.items():
            if k != i:
                candidate_distance_dict[min(i, k)][max(i, k)] = func(
                    [pairwise[l][m] for l in merged for m in members])

        # Drop row and column j from the linkage table.
        del candidate_distance_dict[j]
        for row in candidate_distance_dict.values():
            del row[j]

    # Voters are never merged here, so each voter is its own cluster.
    # NOTE(review): num_v_clusters is still forwarded below; confirm that
    # get_average_support expects this when v_mapping has P.shape[0] labels.
    v_mapping = np.arange(num_voters, dtype=int)
    c_mapping = np.zeros(num_candidates, dtype=int)
    for label, members in enumerate(candidate_cluster.values()):
        c_mapping[list(members)] = label

    avg_support = get_average_support(P, v_mapping, c_mapping, num_v_clusters, num_c_clusters)
    return Summary(avg_support, v_mapping, c_mapping)


def _merge_clusters(clusters, distances, pairwise, i, j, func):
    """Merge cluster ``j`` into cluster ``i`` and refresh the linkage table in place.

    Args:
        clusters: Maps cluster id -> set of member indices; key ``j`` is removed.
        distances: Dict-of-dicts linkage table keyed by cluster ids. Entries
            ``[a][b]`` with ``a < b`` involving ``i`` are rewritten, and row and
            column ``j`` are deleted.
        pairwise: ``pairwise[l][m]`` is the base distance between members ``l`` and ``m``.
        i: Id of the surviving cluster.
        j: Id of the cluster absorbed into ``i``.
        func: Reduction over the list of member-pair distances (the linkage rule).
    """
    # Absorb j into i first, so the refreshed linkages include j's members.
    absorbed = clusters.pop(j)
    clusters[i] |= absorbed
    merged = clusters[i]
    for k, members in clusters.items():
        if k != i:
            distances[min(i, k)][max(i, k)] = func(
                [pairwise[l][m] for l in merged for m in members])
    del distances[j]
    for row in distances.values():
        del row[j]


def _asymmetric_hierarchical(P, num_v_clusters, num_c_clusters, saveas=None, func=None):
    """Agglomeratively cluster both the voters (rows) and the candidates (columns) of ``P``.

    Each step looks only at sides that still have more clusters than their
    target. On each such side it finds the closest pair of clusters. The pair
    with the smaller linkage distance is merged; on ties the voter merge wins.
    This repeats until there are exactly ``num_v_clusters`` voter clusters and
    ``num_c_clusters`` candidate clusters.

    Base distances are normalised L1 distances:
    ``sum(|P[a] - P[b]|) / P.shape[1]`` between voters and
    ``sum(|P[:, a] - P[:, b]|) / P.shape[0]`` between candidates. The linkage
    distance between two clusters is ``func`` applied to the base distances of
    all their member pairs.

    Args:
        P: 2-D array indexed as ``P[voter, candidate]``.
        num_v_clusters: Target number of voter clusters, in ``[1, P.shape[0]]``.
        num_c_clusters: Target number of candidate clusters, in ``[1, P.shape[1]]``.
        saveas: Unused; accepted for interface compatibility.
        func: Reduction over a list of distances (the linkage rule), e.g. ``np.mean``.

    Returns:
        ``Summary(avg_support, v_mapping, c_mapping)``, where the mappings send
        each voter/candidate to the 0-based index of its cluster.

    Raises:
        ValueError: If either target count is outside its valid range.
    """
    num_voters, num_candidates = P.shape[0], P.shape[1]
    # Impossible targets would otherwise end in an UnboundLocalError or TypeError.
    if not 1 <= num_v_clusters <= num_voters:
        raise ValueError(
            f"num_v_clusters must be between 1 and {num_voters}, got {num_v_clusters}")
    if not 1 <= num_c_clusters <= num_candidates:
        raise ValueError(
            f"num_c_clusters must be between 1 and {num_candidates}, got {num_c_clusters}")

    # Use a float view for the distances: subtracting boolean arrays raises in
    # NumPy, and subtracting unsigned ints wraps around.
    values = np.asarray(P, dtype=float)
    voter_pairwise = [[np.sum(np.abs(values[a] - values[b])) / num_candidates
                       for b in range(num_voters)]
                      for a in range(num_voters)]
    candidate_pairwise = [[np.sum(np.abs(values[:, a] - values[:, b])) / num_voters
                           for b in range(num_candidates)]
                          for a in range(num_candidates)]

    # Linkage tables; only entries [a][b] with a < b are read.
    voter_distance_dict = {a: dict(enumerate(row)) for a, row in enumerate(voter_pairwise)}
    candidate_distance_dict = {a: dict(enumerate(row))
                               for a, row in enumerate(candidate_pairwise)}

    # A cluster's id is its smallest member index: merges always keep the smaller id.
    voter_cluster = {v: {v} for v in range(num_voters)}
    candidate_cluster = {c: {c} for c in range(num_candidates)}

    while len(voter_cluster) > num_v_clusters or len(candidate_cluster) > num_c_clusters:
        # A side that has reached its target reports +inf, so it is never chosen.
        v_min_index, v_min_value = None, np.inf
        c_min_index, c_min_value = None, np.inf
        if len(voter_cluster) > num_v_clusters:
            v_min_index, v_min_value = find_min_in_upper_right(voter_distance_dict)
        if len(candidate_cluster) > num_c_clusters:
            c_min_index, c_min_value = find_min_in_upper_right(candidate_distance_dict)

        if v_min_index is not None and v_min_value <= c_min_value:
            i, j = v_min_index
            _merge_clusters(voter_cluster, voter_distance_dict, voter_pairwise, i, j, func)
        else:
            i, j = c_min_index
            _merge_clusters(candidate_cluster, candidate_distance_dict, candidate_pairwise,
                            i, j, func)

    # Label clusters 0..k-1 in the order their ids appear in the cluster dicts.
    v_mapping = np.zeros(num_voters, dtype=int)
    for label, members in enumerate(voter_cluster.values()):
        v_mapping[list(members)] = label
    c_mapping = np.zeros(num_candidates, dtype=int)
    for label, members in enumerate(candidate_cluster.values()):
        c_mapping[list(members)] = label

    avg_support = get_average_support(P, v_mapping, c_mapping, num_v_clusters, num_c_clusters)
    return Summary(avg_support, v_mapping, c_mapping)


def hierarchical_minavg(P, num_v_clusters, num_c_clusters, saveas=None, base_distance=None):
    """Cluster candidates, merging the pair with the smallest mean pairwise distance.

    This is average linkage via ``_symmetric_hierarchical`` with ``func=np.mean``.
    """
    options = dict(saveas=saveas, func=np.mean, base_distance=base_distance)
    return _symmetric_hierarchical(P, num_v_clusters, num_c_clusters, **options)


def hierarchical_minmax(P, num_v_clusters, num_c_clusters, saveas=None, base_distance=None):
    """Cluster candidates, merging the pair with the smallest maximum pairwise distance.

    This is complete linkage via ``_symmetric_hierarchical`` with ``func=np.max``.
    """
    options = dict(saveas=saveas, func=np.max, base_distance=base_distance)
    return _symmetric_hierarchical(P, num_v_clusters, num_c_clusters, **options)


def hierarchical_minmin(P, num_v_clusters, num_c_clusters, saveas=None, base_distance=None):
    """Cluster candidates, merging the pair with the smallest minimum pairwise distance.

    This is single linkage via ``_symmetric_hierarchical`` with ``func=np.min``.
    """
    options = dict(saveas=saveas, func=np.min, base_distance=base_distance)
    return _symmetric_hierarchical(P, num_v_clusters, num_c_clusters, **options)